Skip to content

Conversation

@nhat-nguyen
Copy link
Contributor

@nhat-nguyen nhat-nguyen commented Jan 2, 2025

This PR introduces the triton-to-unstructured pass which is the first step towards allowing triton-shared to compile pointer sequences that cannot be analyzed by triton-to-structured (gather / scatter).

This pass attempts to lower all loads and stores of unstructured pointers to
tts.gather or tts.scatter that take a single base, a tensor of offsets, an
optional tensor of mask values, and a default value in case of load.

In addition, all pointer-producing ops will be eliminated and replaced by
offset-producing ops. tts.gather and tts.scatter will use the pointer
directly from the kernel arguments as opposed to pointer produced by ops such
as tt.addptr and tt.splat.

Example:

module {
  tt.func public @gather_simple_no_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %cst = arith.constant dense<5> : tensor<64xi32>
    %cst_0 = arith.constant dense<10> : tensor<64xi32>
    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %1 = arith.divsi %0, %cst_0 : tensor<64xi32>
    %2 = arith.addi %1, %cst : tensor<64xi32>
    %3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
    %4 = tt.addptr %3, %2 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
    %5 = tt.load %4 : tensor<64x!tt.ptr<f32>>
    %6 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
    %7 = tt.addptr %6, %0 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
    tt.store %7, %5 : tensor<64x!tt.ptr<f32>>
    tt.return
  }
}

becomes

module {
  tt.func public @gather_simple_no_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %cst = arith.constant dense<5> : tensor<64xi32>
    %cst_0 = arith.constant dense<10> : tensor<64xi32>
    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %1 = arith.divsi %0, %cst_0 : tensor<64xi32>
    %2 = arith.addi %1, %cst : tensor<64xi32>
    %3 = tts.gather %arg0[%2] : (<f32>, tensor<64xi32>) -> tensor<64xf32>
    tts.scatter %3 into %arg1[%0] : tensor<64xf32> into (<f32>, tensor<64xi32>)
    tt.return
  }
}

Current assumptions and limitations:

  • For simplicity, the pass assumes that gather / scatter operations load /
    store from / to a single base with a tensor of random offsets. As a
    result, the following triton program would not work:
@triton.jit
def gather_simple(in0, in1, out0):
    offs = tl.arange(0, 8)
    in0_ptrs = in0 + offs
    in1_ptrs = in1 + offs
    ptrs = tl.cat(in0_ptrs, in1_ptrs, can_reorder=True)
    c = tl.load(ptrs)
    out_offs = tl.arange(0, 16)
    tl.store(out0 + out_offs, c)

In the above program, ptrs contains 2 bases: in0 and in1 after the
cat operation.

For more details on the algorithm, see the TritonToUnstructuredPass.cpp file.

Future work

Future work may include scaling the algorithm to support multiple bases -- one
possible solution is to let tts.gather and tts.scatter take in an additional
tensor of base pointers corresponding to the tensor of offsets. But because
we do not want pointer-producing ops to be present after this pass, we can
use a tensor of index where each element indicates the index of the pointer
argument to be used. The drawback is a gather or scatter operation now needs
one extract lookup to get the base which will affect performance.


Intended lowering pipeline

  • triton-to-structured (no changes):
    • analyzes structured addptr sequences
      • introduces tts.make_tptr %ptr_arg with offsets and strides
      • introduces tts.load and tts.store
    • leaves unstructured addptr sequences and their corresponding tt.load and tt.store intact
  • triton-to-unstructured (Introduce triton-to-unstructured pass #210):
    • introduces tts.gather and tts.scatter
    • removes all pointer-producing ops such as tt.addptr and tt.splat and replaces them with offset-producing ops
  • structured-to-memref (Update structured-to-memref pass to support the new pass pipeline #217):
    • currently converts everything to memref including scalar addptr and kernel arguments
    • will change to just convert ops in the tts dialect to memref with the exception of tts.gather and tts.scatter
  • unstructured-to-memref (Introduce unstructured-to-memref pass #216):
    • converts the remaining unstructured tts.gather, tts.scatter into memref
  • triton-ptr-to-memref (Introduce triton-ptr-to-memref pass #211):
    • converts kernel arguments with pointer type to memref

@nhat-nguyen nhat-nguyen changed the title Introduce fold-unstructured-triton-ptr pass Introduce fold-unstructured-ptr pass Jan 3, 2025
@nhat-nguyen nhat-nguyen marked this pull request as ready for review January 6, 2025 19:21
@nhat-nguyen nhat-nguyen changed the title Introduce fold-unstructured-ptr pass Introduce triton-to-unstructured pass Jan 10, 2025
@nhat-nguyen nhat-nguyen marked this pull request as ready for review January 10, 2025 20:20
@haishanzzzz
Copy link
Contributor

Thank you for the PR! I have a few general questions:

  1. The transformation this pass performs seems to be targeting hardware that uses "base ptr + a vector of offset" in performing gather/scatter, while GPU-like architectures simply create a tensor of addresses. Is that correct?

  2. Would there still be kernels that can't be lowered to this pattern, even though very unlikely? E.g., would it be possible to have a vector of int64s from memory, convert them to pointers, and use it to perform ld/st?

  3. You mentioned one future work is "scaling the algorithm to support multiple bases". I thought it is a very promising direction. Do you have specific ideas / timeline in mind? Would love to discuss more offline.

@kile01
Copy link
Contributor

kile01 commented Jan 13, 2025

I may have missed it, but do you have any examples of the tts.gather with a default value?

@nhat-nguyen
Copy link
Contributor Author

3. You mentioned one future work is "scaling the algorithm to support multiple bases". I thought it is a very promising direction. Do you have specific ideas / timeline in mind? Would love to discuss more offline.

Oh I do have these but in another PR, let me copy them over to this PR too. Thanks for noticing :D

Copy link
Contributor

@kile01 kile01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Copy link

@beicy beicy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@nhat-nguyen nhat-nguyen merged commit 986cea8 into main Jan 14, 2025
3 checks passed
@nhat-nguyen nhat-nguyen deleted the nhat/fold_unstructured branch January 14, 2025 19:01
nhat-nguyen added a commit that referenced this pull request Jan 14, 2025
This PR introduces the `triton-ptr-to-memref` pass responsible for
converting function signature that uses triton ptr to use memref
instead. This is part of the work to allow triton-shared to lower gather
/ scatter pointer sequences.

Much of this code is copied from the current `StructuredToMemref` pass
which will be cleaned up in a later PR.

---

# Intended lowering pipeline
- triton-to-structured (no changes):
    - analyzes structured addptr sequences
        - introduces `tts.make_tptr %ptr_arg with offsets and strides`
        - introduces `tts.load` and `tts.store`
- leaves unstructured addptr sequences and their corresponding `tt.load`
and `tt.store` intact
- triton-to-unstructured (#210):
    - introduces `tts.gather` and `tts.scatter`
- removes all pointer-producing ops such as `tt.addptr` and `tt.splat`
and replaces them with offset-producing ops
- structured-to-memref (#217):
- currently converts everything to memref including scalar addptr and
kernel arguments
- will change to just convert ops in the `tts` dialect to `memref` with
the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
- converts the remaining unstructured `tts.gather`, `tts.scatter` into
memref
- triton-ptr-to-memref (#211):
    - converts kernel arguments with pointer type to memref
nhat-nguyen added a commit that referenced this pull request Jan 15, 2025
This PR introduces the `unstructured-to-memref` pass responsible for
converting unstructured triton load / store ops to memref load / store
ops. This is part of the work to allow triton-shared to lower gather /
scatter pointer sequences. The pass is intended to be used after running
`--fold-unstructured-ptr`.

Triton load op (gather) is lowered to a `linalg.generic` whose body
contains a load from the offset indicated by the offset provided by
`tts.make_unstructured_tptr`. For load op with mask, an inner-most
`scf.if` is used to return a default value (or the `other` in `tt.load`
if provided) if the corresponding mask value is false.

Example of a load:

```mlir
  func.func @gather_simple_mask_with_other(%arg0: memref<*xf32>, %arg1: memref<*xf32>) {
      %cst = arith.constant -1.000000e+00 : f32
      %cast = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
      %load_tensor = bufferization.to_tensor %cast restrict : memref<?xf32>
      %out = tensor.empty() : tensor<64xf32>
      %gather = linalg.generic {
        iterator_types = ["parallel"]
      } ins(%offset_tensor, %mask_tensor : tensor<64xi32>, tensor<64xi1>)
        outs(%out : tensor<64xf32>) {
      ^bb0(%offset: i32, %mask: i1, %out: f32):
        %yield = scf.if %mask -> (f32) {
          %index = arith.index_cast %offset : i32 to index
          %extracted = tensor.extract %load_tensor[%index] : tensor<?xf32>
          scf.yield %extracted : f32
        } else {
          scf.yield %cst : f32
        }
        linalg.yield %yield : f32
      } -> tensor<64xf32>
```

Triton store op (scatter) is lowered to an `affine.for` loop nest that
stores the value to the appropriate offset provided by
`tts.make_unstructured_tptr`. Store op with mask is also supported.

Example of a store:

```mlir
  func.func @masked_gather_scatter(%arg0: memref<*xf32>, %arg1: memref<*xf32>) {
    %store_memref = memref.cast %arg1 : memref<*xf32> to memref<?xf32>
    affine.for %i = 0 to 4 {
      %mask_val = tensor.extract %mask[%i] : tensor<4xi1>
      scf.if %mask_val {
        %offset_val = tensor.extract %offset_tensor[%i] : tensor<4xi32>
        %store_value = tensor.extract %tensor[%i] : tensor<4xf32>
        %offset_index = arith.index_cast %offset_val : i32 to index
        memref.store %store_value, %store_memref[%offset_index] : memref<?xf32>
      }
    }
```

---

# Intended lowering pipeline
- triton-to-structured (no changes):
    - analyzes structured addptr sequences
        - introduces `tts.make_tptr %ptr_arg with offsets and strides`
        - introduces `tts.load` and `tts.store`
- leaves unstructured addptr sequences and their corresponding `tt.load`
and `tt.store` intact
- triton-to-unstructured (#210):
    - introduces `tts.gather` and `tts.scatter`
- removes all pointer-producing ops such as `tt.addptr` and `tt.splat`
and replaces them with offset-producing ops
- structured-to-memref (#217):
- currently converts everything to memref including scalar addptr and
kernel arguments
- will change to just convert ops in the `tts` dialect to `memref` with
the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
- converts the remaining unstructured `tts.gather`, `tts.scatter` into
memref
- triton-ptr-to-memref (#211):
    - converts kernel arguments with pointer type to memref
nhat-nguyen added a commit that referenced this pull request Jan 16, 2025
…217)

This PR simplifies the `structured-to-memref` pass responsible for
converting structured triton load / store ops to memref load / store
ops. This is part of the work to allow triton-shared to lower gather /
scatter pointer sequences. Previously, this pass is also responsible for
converting scalar pointer load and store into memref; that
transformation has now been moved to `unstructured-to-memref`.

In addition, the PR also updates the `triton-to-linalg-experimental`
pass to fully utilize all the new passes. Once merged, triton-shared now
fully supports gather / scatter. An example test
(`test_gather_scatter.py`) is also added to demonstrate this new
capability.

---

# Intended lowering pipeline
- triton-to-structured (no changes):
    - analyzes structured addptr sequences
        - introduces `tts.make_tptr %ptr_arg with offsets and strides`
        - introduces `tts.load` and `tts.store`
- leaves unstructured addptr sequences and their corresponding `tt.load`
and `tt.store` intact
- triton-to-unstructured (#210):
    - introduces `tts.gather` and `tts.scatter`
- removes all pointer-producing ops such as `tt.addptr` and `tt.splat`
and replaces them with offset-producing ops
- structured-to-memref (#217):
- currently converts everything to memref including scalar addptr and
kernel arguments
- will change to just convert ops in the `tts` dialect to `memref` with
the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
- converts the remaining unstructured `tts.gather`, `tts.scatter` into
memref
- triton-ptr-to-memref (#211):
    - converts kernel arguments with pointer type to memref
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants